{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    "#!/usr/bin/env python\n",
    "# coding: utf-8\n",
    "\n",
    "import os\n",
    "import math\n",
    "import time\n",
    "import json\n",
    "import random\n",
    "\n",
    "from collections import OrderedDict\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "from data.data_iterator import TextIterator\n",
    "\n",
    "import data.util as util\n",
    "import data.data_utils as data_utils\n",
    "from data.data_utils import prepare_batch\n",
    "from data.data_utils import prepare_train_batch\n",
    "\n",
    "from seq2seq_model import Seq2SeqModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Decoding parameters\n",
    "tf.app.flags.DEFINE_integer('beam_width', 12, 'Beam width used in beamsearch')\n",
    "tf.app.flags.DEFINE_integer('decode_batch_size', 80, 'Batch size used for decoding')\n",
    "tf.app.flags.DEFINE_integer('max_decode_step', 500, 'Maximum time step limit to decode')\n",
    "tf.app.flags.DEFINE_boolean('write_n_best', False, 'Write n-best list (n=beam_width)')\n",
    "tf.app.flags.DEFINE_string('model_path', None, 'Path to a specific model checkpoint.')\n",
    "tf.app.flags.DEFINE_string('decode_input', 'data/newstest2012.bpe.de', 'Decoding input path')\n",
    "tf.app.flags.DEFINE_string('decode_output', 'data/newstest2012.bpe.de.trans', 'Decoding output path')\n",
    "\n",
    "# Runtime parameters\n",
    "tf.app.flags.DEFINE_boolean('allow_soft_placement', True, 'Allow device soft placement')\n",
    "tf.app.flags.DEFINE_boolean('log_device_placement', False, 'Log placement of ops on devices')\n",
    "\n",
    "FLAGS = tf.app.flags.FLAGS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_config(FLAGS):\n",
    "    \n",
    "    config = util.unicode_to_utf8(\n",
    "        json.load(open('%s.json' % FLAGS.model_path, 'rb')))\n",
    "    for key, value in FLAGS.__flags.items():\n",
    "        config[key] = value\n",
    "\n",
    "    return config\n",
    "\n",
    "\n",
    "def load_model(session, config):\n",
    "    \n",
    "    model = Seq2SeqModel(config, 'decode')\n",
    "    if tf.train.checkpoint_exists(FLAGS.model_path):\n",
    "        print 'Reloading model parameters..'\n",
    "        model.restore(session, FLAGS.model_path)\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            'No such file:[{}]'.format(FLAGS.model_path))\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def decode():\n",
    "    # Load model config\n",
    "    config = load_config(FLAGS)\n",
    "\n",
    "    # Load source data to decode\n",
    "    test_set = TextIterator(source=config['decode_input'],\n",
    "                            batch_size=config['decode_batch_size'],\n",
    "                            source_dict=config['source_vocabulary'],\n",
    "                            maxlen=None,\n",
    "                            n_words_source=config['num_encoder_symbols'])\n",
    "\n",
    "    # Load inverse dictionary used in decoding\n",
    "    target_inverse_dict = data_utils.load_inverse_dict(config['target_vocabulary'])\n",
    "    \n",
    "    # Initiate TF session\n",
    "    with tf.Session(config=tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement, \n",
    "        log_device_placement=FLAGS.log_device_placement, gpu_options=tf.GPUOptions(allow_growth=True))) as sess:\n",
    "\n",
    "        # Reload existing checkpoint\n",
    "        model = load_model(sess, config)\n",
    "        try:\n",
    "            print 'Decoding {}..'.format(FLAGS.decode_input)\n",
    "            if FLAGS.write_n_best:\n",
    "                fout = [data_utils.fopen((\"%s_%d\" % (FLAGS.decode_output, k)), 'w') \\\n",
    "                        for k in range(FLAGS.beam_width)]\n",
    "            else:\n",
    "                fout = [data_utils.fopen(FLAGS.decode_output, 'w')]\n",
    "            \n",
    "            for idx, source_seq in enumerate(test_set):\n",
    "                source, source_len = prepare_batch(source_seq)\n",
    "                # predicted_ids: GreedyDecoder; [batch_size, max_time_step, 1]\n",
    "                # BeamSearchDecoder; [batch_size, max_time_step, beam_width]\n",
    "                predicted_ids = model.predict(sess, encoder_inputs=source, \n",
    "                                              encoder_inputs_length=source_len)\n",
    "                   \n",
    "                # Write decoding results\n",
    "                for k, f in reversed(list(enumerate(fout))):\n",
    "                    for seq in predicted_ids:\n",
    "                        f.write(str(data_utils.seq2words(seq[:,k], target_inverse_dict)) + '\\n')\n",
    "                    if not FLAGS.write_n_best:\n",
    "                        break\n",
    "                print '  {}th line decoded'.format(idx * FLAGS.decode_batch_size)\n",
    "                \n",
    "            print 'Decoding terminated'\n",
    "        except IOError:\n",
    "            pass\n",
    "        finally:\n",
    "            [f.close() for f in fout]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def main(_):\n",
    "    decode()\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    tf.app.run()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
